%% Setup
% Builds a noisy symmetric target M = U_true*U_true' + noise and a spectral
% reference factor U_star with its residual g_star.
clear; clc;
rng(142);                    % fixed seed so the data and initialisations are reproducible

% Problem dimensions
n = 10;                      % M is n-by-n
r_true = 10;                 % number of columns of the generating factor
r_model = 10;                % number of columns of the fitted factor U

% Optimiser settings
lr = 1e-5;                   % step size
num_iters = 1e6;             % iterations per run
sigma = 0.01;                % scale of the symmetric noise added to M
alpha = 1;                   % smoothing parameter of f

% Trade-off values swept for each method family
beta_list_dbgd = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1];
beta_list_penalty = [1,2,5,10,20,50,100,200,500,1000];

% Synthetic data: Gram matrix of a random factor plus symmetrised Gaussian noise.
U_true = randn(n, r_true);
M_clean = U_true * U_true';
noise = randn(n); noise = (noise + noise') / 2 * sigma;
M = M_clean + noise;

% Spectral reference: keep the r_model largest eigenpairs of M.
[V, D] = eig(M);
[lambda_sorted, idx] = sort(diag(D), 'descend');
% The noise can push small eigenvalues of M below zero. A factorisation
% U*U' cannot represent a negative eigenvalue, and sqrt() of one would make
% U_star complex, so negative eigenvalues are clamped to zero.
Lambda_r = diag(max(lambda_sorted(1:r_model), 0));
Q_r = V(:, idx(1:r_model));
U_star = Q_r * sqrt(Lambda_r);
g_star = norm(M - U_star * U_star', 'fro')^2;   % residual of the spectral reference

% Objective f and its gradient; both act elementwise on the factor matrix.

% Option 1 (disabled): f(U) = sum_ij sqrt(U_ij^2 + alpha)
% f_fun = @(U) sum(sqrt(U(:).^2 + alpha));
% grad_f_fun = @(U) U ./ sqrt(U.^2 + alpha);


% Option 2 (active): f(U) = sum_ij log(1 + U_ij^2 / alpha)
f_fun = @(X) sum(log(1 + X(:).^2 / alpha));
grad_f_fun = @(X) (2 * X) ./ (alpha + X.^2);

% project(v, onto): multiple of `onto` whose coefficient is <v, onto> / <onto, onto>,
% with <.,.> the sum of elementwise products. Returns NaN entries if `onto` is all zeros.
project = @(vec, base) ((sum(vec(:) .* base(:)) / sum(base(:).^2)) * base);

% Run experiments
% One run per (method family, beta) pair. The label "<Family>_beta=<value>"
% selects the update rule inside run_log and, after makeValidName, becomes
% the field name under which the run is stored in `results`.
methods = [string("DBGD_beta=") + beta_list_dbgd, string("Penalty_beta=") + beta_list_penalty];
% Numeric betas in the same order as `methods`. Parsing the value back out
% of the label would silently round it to the few digits that the
% number-to-string conversion keeps.
betas = [beta_list_dbgd, beta_list_penalty];
results = struct();

for i = 1:length(methods)
    name = methods(i);
    beta = betas(i);
    field = matlab.lang.makeValidName(name);
    results.(field) = run_log(name, beta, M, n, r_model, num_iters, lr, alpha, f_fun, grad_f_fun, project);
end

%% Metric names and plot titles
% metric_keys{k} is a field of every run_log result and titles{k} is the
% matching plot title; the two lists are index-aligned.
metric_keys = { ...
    'g', ...
    'f', ...
    'grad_g_norm', ...
    'grad_f_norm', ...
    'proj_f', ...
    'orth_f', ...
    'cos_angle'};
titles = { ...
    'g(U)', ...
    'f(U)', ...
    '\|\nabla g(U)\|', ...
    '\|\nabla f(U)\|', ...
    'Parallel part of \nabla f', ...
    'Orthogonal part of \nabla f', ...
    'cos(angle(\nabla f, \nabla g))'};

% Field names of `results`, one per completed run.
methods_cell = fieldnames(results);

% for i = 1:length(metric_keys)
%     key = metric_keys{i};
%     figure; hold on; grid on;
%     for j = 1:length(methods_cell)
%         m = methods_cell{j};
%         vals = results.(m).(key);
%         plot(vals, 'DisplayName', strrep(m, '_', '\_'));
%     end
%     if ~strcmp(key, 'cos_angle')
%         set(gca, 'YScale', 'log');
%     end
%     title([titles{i} ' (Log-Smooth)'], 'Interpreter', 'latex');
%     xlabel('Iteration');
%     legend('Interpreter', 'latex', 'FontSize', 8);
%     set(gca, 'FontSize', 12);
%     hold off;
% end


% %% Final (g, f) SCATTER plot
% figure('Color', 'w'); hold on; grid on;
% 
% % Dummy plots for legend
% h_dbgd = plot(nan, nan, 'o', 'MarkerFaceColor', 'b', 'MarkerEdgeColor', 'k', ...
%               'MarkerSize', 8, 'DisplayName', 'DBGD');
% h_penalty = plot(nan, nan, 's', 'MarkerFaceColor', 'r', 'MarkerEdgeColor', 'k', ...
%                  'MarkerSize', 8, 'DisplayName', 'Penalty');
% 
% % Actual data points
% for j = 1:length(methods_cell)
%     m = methods_cell{j};
%     gval = results.(m).final_g;
%     fval = results.(m).final_f;
% 
%     if contains(m, 'Penalty')
%         color = 'r'; marker = 's';
%     else
%         color = 'b'; marker = 'o';
%     end
% 
%     scatter(gval, fval, 100, color, marker, 'filled', 'MarkerEdgeColor', 'k');
% end
% 
% set(gca, 'XScale', 'log', 'YScale', 'log', 'FontSize', 14);
% xlabel('$g(U)$', 'Interpreter', 'latex', 'FontSize', 16);
% ylabel('$f(U)$', 'Interpreter', 'latex', 'FontSize', 16);
% legend([h_dbgd, h_penalty], 'Location', 'southeast', 'Box', 'off', 'FontSize', 12);
% axis tight;
% box on;

%% Final (g, f) SCATTER plot
% One marker per run at its last recorded (g, f) pair, on log-log axes.
% Blue circles are DBGD runs, red squares are Penalty runs.
figure('Color', 'w');
hold on;
grid on;

% NaN-valued proxy lines: they draw nothing but give the legend exactly one
% entry per method family.
h_dbgd = plot(nan, nan, 'o', 'MarkerFaceColor', 'b', 'MarkerEdgeColor', 'k', ...
              'MarkerSize', 10, 'DisplayName', 'DBGD');
h_penalty = plot(nan, nan, 's', 'MarkerFaceColor', 'r', 'MarkerEdgeColor', 'k', ...
                 'MarkerSize', 10, 'DisplayName', 'Penalty');

% Row 1: style for DBGD runs, row 2: style for Penalty runs.
family_style = {'b', 'o'; 'r', 's'};

for j = 1:numel(methods_cell)
    run_name = methods_cell{j};
    run_data = results.(run_name);
    row = 1 + contains(run_name, 'Penalty');

    scatter(run_data.final_g, run_data.final_f, 120, ...
            family_style{row, 1}, family_style{row, 2}, ...
            'filled', 'MarkerEdgeColor', 'k');
end

ax = gca;
ax.XScale = 'log';
ax.YScale = 'log';
ax.FontSize = 18;
xlabel('$g(U)$', 'Interpreter', 'latex', 'FontSize', 22, 'FontWeight', 'bold');
ylabel('$f(U)$', 'Interpreter', 'latex', 'FontSize', 22, 'FontWeight', 'bold');
legend([h_dbgd, h_penalty], 'Location', 'southeast', ...
       'Interpreter', 'latex', 'FontSize', 18, 'Box', 'off');
% title('Final Objective Values', 'Interpreter', 'latex', 'FontSize', 20);
axis tight;
box on;



%% Final (||∇g||, ||orth ∇f||) SCATTER plot
% One marker per run at its last recorded pair
% (norm of grad g, norm of the part of grad f orthogonal to grad g), log-log axes.
figure('Color', 'w');
hold on;
grid on;

% NaN-valued proxy lines used only to populate the legend.
h_dbgd = plot(nan, nan, 'o', 'MarkerFaceColor', 'b', 'MarkerEdgeColor', 'k', ...
              'MarkerSize', 10, 'DisplayName', 'DBGD');
h_penalty = plot(nan, nan, 's', 'MarkerFaceColor', 'r', 'MarkerEdgeColor', 'k', ...
                 'MarkerSize', 10, 'DisplayName', 'Penalty');

% Row 1: style for DBGD runs, row 2: style for Penalty runs.
family_style = {'b', 'o'; 'r', 's'};

for j = 1:numel(methods_cell)
    run_name = methods_cell{j};
    run_data = results.(run_name);
    last_grad_g = run_data.grad_g_norm(end);
    last_orth_f = run_data.orth_f(end);
    row = 1 + contains(run_name, 'Penalty');

    scatter(last_grad_g, last_orth_f, 120, ...
            family_style{row, 1}, family_style{row, 2}, ...
            'filled', 'MarkerEdgeColor', 'k');
end

ax = gca;
ax.XScale = 'log';
ax.YScale = 'log';
ax.FontSize = 18;
xlabel('$\|\nabla g(U)\|$', 'Interpreter', 'latex', 'FontSize', 22, 'FontWeight', 'bold');
ylabel('$\|\nabla_{\perp} f(U)\|$', 'Interpreter', 'latex', 'FontSize', 22, 'FontWeight', 'bold');
legend([h_dbgd, h_penalty], 'Location', 'southeast', ...
       'Interpreter', 'latex', 'FontSize', 18, 'Box', 'off');
% title('Final Stationarity Metrics', 'Interpreter', 'latex', 'FontSize', 20);
axis tight;
box on;

%% function
function result = run_log(method, beta, M, n, r_model, num_iters, lr, alpha, f_fun, grad_f_fun, project)
%RUN_LOG Run one first-order method on U and log per-iteration diagnostics.
%
%   With g(U) = ||M - U*U'||_F^2 and a user-supplied f, the iterate is
%   updated as U <- U - lr * d, where d depends on the method family:
%     'DBGD...'    : d = grad_f + lam * grad_g,
%                    lam = max((beta*||grad_g||^2 - <grad_f, grad_g>) / ||grad_g||^2, 0)
%     'Penalty...' : d = (beta * grad_g + grad_f) / (1 + beta)
%
%   Inputs
%     method      text whose prefix ('DBGD' or 'Penalty') selects the update
%     beta        scalar trade-off parameter of the selected update
%     M           target matrix; must be size-compatible with U*U' (n-by-n)
%     n, r_model  U is initialised as randn(n, r_model) * 0.1
%     num_iters   number of iterations (length of every logged trace)
%     lr          step size
%     alpha       accepted for interface compatibility; not used in this function
%     f_fun       handle, f_fun(U) -> scalar objective value
%     grad_f_fun  handle, grad_f_fun(U) -> gradient of f, same size as U
%     project     handle, project(v, onto) -> component of v along onto
%
%   Output: struct with 1-by-num_iters traces g, f, grad_g_norm,
%   grad_f_norm, proj_f, orth_f, cos_angle and lambda_t (zero for Penalty
%   runs), plus final_g and final_f, the last logged entries of g and f.
%   Each entry is measured BEFORE the update of that iteration.
%
%   Errors: 'run_log:unknownMethod' if method starts with neither prefix.

    % Resolve the method family once. It is loop-invariant, and validating
    % here gives a clear error instead of an undefined-variable failure
    % inside the loop.
    use_dbgd = startsWith(method, 'DBGD');
    use_penalty = ~use_dbgd && startsWith(method, 'Penalty');
    if ~use_dbgd && ~use_penalty
        error('run_log:unknownMethod', ...
              'Unknown method "%s": expected a name starting with ''DBGD'' or ''Penalty''.', ...
              method);
    end

    U = randn(n, r_model) * 0.1;   % small random initialisation (consumes the global RNG stream)

    g_vals = zeros(1, num_iters);
    f_vals = zeros(1, num_iters);
    grad_g_norms = zeros(1, num_iters);
    grad_f_norms = zeros(1, num_iters);
    proj_norms = zeros(1, num_iters);
    ortho_norms = zeros(1, num_iters);
    cos_angles = zeros(1, num_iters);
    lam_trace = zeros(1, num_iters);

    for k = 1:num_iters
        R = M - U * U';            % residual of the current factorisation
        grad_g = -4 * R * U;
        grad_f_val = grad_f_fun(U);

        g_vals(k) = norm(R, 'fro')^2;
        f_vals(k) = f_fun(U);
        grad_g_norms(k) = norm(grad_g, 'fro');
        grad_f_norms(k) = norm(grad_f_val, 'fro');

        % Split grad f into its component along grad g and the remainder.
        proj = project(grad_f_val, grad_g);
        proj_norms(k) = norm(proj, 'fro');
        ortho_norms(k) = norm(grad_f_val - proj, 'fro');

        % Cosine of the angle between the two gradients (0 if either vanishes).
        inner = sum(grad_f_val(:) .* grad_g(:));
        norm_prod = norm(grad_f_val(:)) * norm(grad_g(:));
        if norm_prod > 0
            cos_angles(k) = inner / norm_prod;
        else
            cos_angles(k) = 0;
        end

        if use_dbgd
            grad_g_sq = sum(grad_g(:).^2);
            phi = beta * grad_g_sq;
            % If grad_g is exactly zero this is 0/0 = NaN; max(NaN, 0)
            % returns 0, so the step falls back to plain grad f.
            lam = max((phi - inner) / grad_g_sq, 0);
            grad = (grad_f_val + lam * grad_g);
            lam_trace(k) = lam;
        else
            grad = (beta * grad_g + grad_f_val)/(1+beta);
        end
        U = U - lr * grad;
    end

    result.g = g_vals;
    result.f = f_vals;
    result.grad_g_norm = grad_g_norms;
    result.grad_f_norm = grad_f_norms;
    result.proj_f = proj_norms;
    result.orth_f = ortho_norms;
    result.cos_angle = cos_angles;
    result.lambda_t = lam_trace;
    result.final_g = g_vals(end);
    result.final_f = f_vals(end);
end
